import subprocess
import argparse
import re
import os
import time
import logging

class OptimalBS(object):
    '''Search for the optimal batch size of a model file.

    For every batch size in [min_bs, max_bs] a copy of the model file is
    written with that batch size, benchmarked with "caffe time", and the
    batch size giving the highest images/sec figure is reported.
    '''

    def __init__(self, main_args):
        '''Store the search settings, set up file logging and validate them.

        main_args must provide the attributes model_file, min_bs, max_bs,
        iterations and dummy_data_use (see parse_args).

        Raises IOError / ValueError (from check_parameters) when the settings
        are unusable.
        '''
        self.model_file = main_args.model_file
        self.min_bs = main_args.min_bs
        self.max_bs = main_args.max_bs
        self.iterations = main_args.iterations
        self.dummy_use = main_args.dummy_data_use
        current_time = time.strftime("%Y%m%d%H%M%S")
        logging.basicConfig(
            filename='result-optimal_batch_size-{}.log'.format(current_time),
            level=logging.INFO)
        self.check_parameters()

    def check_parameters(self):
        '''Validate the search settings.

        Raises IOError when the model file does not exist and ValueError when
        the batch size bounds are missing or inverted. Every problem is also
        written to the log before it is raised.
        '''
        if not self.model_file or not os.path.isfile(self.model_file):
            message = 'model file {} does not exist, please specify it.'.format(self.model_file)
            logging.error(message)
            raise IOError(message)
        if self.min_bs is None or self.max_bs is None:
            message = 'min bs: {} and max bs: {} must both be specified.'.format(self.min_bs, self.max_bs)
            logging.error(message)
            raise ValueError(message)
        if self.min_bs > self.max_bs:
            message = 'max bs: {} smaller than min bs: {}.'.format(self.max_bs, self.min_bs)
            logging.error(message)
            raise ValueError(message)

    def _gen_new_model_file_name(self, new_bs):
        '''Return the path of the model copy for batch size new_bs.

        The copy lives next to the original model file and is named
        "<new_bs>_<original file name>".
        '''
        model_path = os.path.dirname(self.model_file)
        model_file_name = os.path.basename(self.model_file)
        new_model_file_name = '_'.join([str(new_bs), model_file_name])
        # os.path.join keeps the name relative when model_path is empty,
        # instead of pointing at the filesystem root.
        return os.path.join(model_path, new_model_file_name)

    def gen_model_with_new_bs(self, new_bs):
        '''Write a copy of the model file using batch size new_bs and return its path.'''
        new_bs = str(new_bs)
        if self.dummy_use:
            return self._gen_new_dummy_model(new_bs)
        return self._gen_new_real_model(new_bs)

    def _gen_new_real_model(self, new_bs):
        '''Write a model copy for real data input and return its path.

        The first number on the first indented "batch_size:" line and on every
        "bn_stats_batch_size:" line is replaced by new_bs.

        Raises ValueError (and removes the useless copy) when no "batch_size:"
        line is found, because the copy would be identical to the original.
        '''
        new_bs = str(new_bs)
        new_model_file = self._gen_new_model_file_name(new_bs)
        batch_size_pattern = re.compile(r"^\s+batch_size:.*")
        bn_stats_batch_size_pattern = re.compile(r".*bn_stats_batch_size:.*")
        number_pattern = re.compile(r'[0-9]+')
        # we only care about train phase batch size
        batch_size_cnt = 1
        cnt = 0
        with open(self.model_file, 'r') as src_f, open(new_model_file, 'w') as dst_f:
            for line in src_f:
                if cnt < batch_size_cnt and batch_size_pattern.match(line):
                    line = number_pattern.sub(new_bs, line, count=1)
                    cnt += 1
                # consider bn_stats_batch_size field
                if bn_stats_batch_size_pattern.match(line):
                    line = number_pattern.sub(new_bs, line, count=1)
                dst_f.write(line)
        if cnt < batch_size_cnt:
            os.remove(new_model_file)
            message = "Error: can't find batch size pattern within the model file {}.".format(self.model_file)
            logging.error(message)
            raise ValueError(message)
        return new_model_file

    def _gen_new_dummy_model(self, new_bs):
        '''Write a model copy for dummy data input and return its path.

        For each of the first two "shape:" entries the first number of the
        first following "dim:" line (which may be the same line) is replaced
        by new_bs. A "dim:" line seen before any "shape:" entry is treated the
        same way.

        Raises ValueError (and removes the useless copy) when nothing could be
        replaced; logs a warning when fewer than two "shape:" entries exist.
        '''
        new_bs = str(new_bs)
        new_model_file = self._gen_new_model_file_name(new_bs)
        shape_pattern = re.compile(r".*shape:.*")
        dim_pattern = re.compile(r".*dim:.*")
        number_pattern = re.compile(r'[0-9]+')
        # we only care about train phase batch size
        shape_cnt = 2
        cnt = 0
        dim_cnt = 0
        replaced = 0
        with open(self.model_file, 'r') as src_f, open(new_model_file, 'w') as dst_f:
            for line in src_f:
                if shape_pattern.match(line):
                    cnt += 1
                    # a new shape starts: its first dim is the batch size
                    dim_cnt = 0
                if dim_cnt == 0 and cnt <= shape_cnt and dim_pattern.match(line):
                    line, substitutions = number_pattern.subn(new_bs, line, count=1)
                    dim_cnt += 1
                    replaced += substitutions
                dst_f.write(line)
        message = "Error: can't find batch size pattern within the model file {}.".format(self.model_file)
        if replaced == 0:
            os.remove(new_model_file)
            logging.error(message)
            raise ValueError(message)
        if cnt < shape_cnt:
            logging.warning(message)
        return new_model_file

    def calculate_fps(self, train_log_file, bs):
        '''Return the images/sec figure for batch size bs from a benchmark log.

        The speed is bs * 1000 / t, where t is the second to last
        whitespace-separated token of the first line containing
        "Average Forward-Backward:" (presumably milliseconds per iteration;
        confirm against the caffe output).

        Raises IOError when the log file is missing and ValueError when the
        timing line cannot be found.
        '''
        if not os.path.isfile(train_log_file):
            message = "Error: traing log file {} does not exist...".format(train_log_file)
            logging.error(message)
            raise IOError(message)
        average_time = ''
        average_fwd_bwd_time_pattern = re.compile(r".*Average Forward-Backward:.*")
        with open(train_log_file, 'r') as f:
            for line in f:
                if average_fwd_bwd_time_pattern.match(line):
                    average_time = line.split()[-2]
                    break
        if average_time == "":
            message = ("Error: can't find average forward-backward time within logs, "
                       "please check logs under: {}".format(train_log_file))
            logging.error(message)
            raise ValueError(message)
        average_time = float(average_time)
        speed = float(bs) * 1000.0 / average_time
        logging.info("bs: {}, benchmark speed: {} images/sec".format(str(bs), str(speed)))
        return speed

    def exec_command_and_show(self, exec_command):
        '''Run a shell command and echo its output line by line.

        Raises subprocess.CalledProcessError when the command exits with a
        non-zero status.
        '''
        # stderr is merged into stdout so that an unread stderr pipe can never
        # fill up and block the child process.
        process = subprocess.Popen(
            exec_command, shell=True, stdin=subprocess.PIPE,
            stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
            universal_newlines=True)
        # Nothing is ever sent to the child; closing stdin gives it EOF
        # instead of letting it wait for input forever.
        process.stdin.close()
        try:
            for stdout_line in iter(process.stdout.readline, ''):
                # the line already ends with a newline; do not print a second one
                print(stdout_line.rstrip('\n'))
        finally:
            process.stdout.close()
            return_code = process.wait()
        if return_code:
            raise subprocess.CalledProcessError(return_code, exec_command)

    def obtain_fps(self, bs):
        '''Benchmark the model with batch size bs and return its images/sec.

        The caffe binary is expected at build/tools/caffe below the parent of
        the directory containing this script. The generated model copy is
        always removed; the benchmark log is removed only after it was parsed
        successfully, so it stays available for inspection on failure.
        '''
        caffe_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        caffe_tool = os.path.join(caffe_root, 'build', 'tools', 'caffe')
        mode = 'time'
        current_time = time.strftime("%Y%m%d%H%M%S")
        train_log_file = 'result-train-bs{}-{}.log'.format(str(bs), current_time)
        new_model_file = self.gen_model_with_new_bs(bs)
        try:
            exec_command = ' '.join([caffe_tool, mode, '--model', new_model_file,
                                     '--iterations', str(self.iterations)])
            exec_command += ' 2>&1 | tee {}'.format(train_log_file)
            self.exec_command_and_show(exec_command)
            fps = self.calculate_fps(train_log_file, bs)
        finally:
            os.remove(new_model_file)
        os.remove(train_log_file)
        return fps

    def find_optimal_bs(self):
        '''Benchmark every batch size in [min_bs, max_bs] and return the fastest.

        When several batch sizes reach the same speed the smallest one wins.
        '''
        optimal_bs, max_fps = self.min_bs, self.obtain_fps(self.min_bs)
        for bs in range(self.min_bs + 1, self.max_bs + 1):
            fps = self.obtain_fps(bs)
            if fps > max_fps:
                optimal_bs, max_fps = bs, fps
        return optimal_bs

def parse_args(argv=None):
    '''Parse the command line options of the batch size search.

    argv: optional list of argument strings; when None (the default) argparse
    reads them from sys.argv[1:].

    Returns the argparse namespace with the attributes model_file, min_bs,
    max_bs, iterations and dummy_data_use. Exits with a usage message when a
    mandatory option (--model_file, --min_bs, --max_bs) is missing.
    '''
    description = 'Used to obtain optimal batch size for specified model file.'
    arg_parser = argparse.ArgumentParser(description=description)
    # The search cannot run without these three, so fail early with a usage
    # message instead of passing None further down.
    arg_parser.add_argument('--model_file', required=True,
                            help='model prototxt file you specified')
    arg_parser.add_argument('--min_bs', type=int, required=True,
                            help='min batch size you expect')
    arg_parser.add_argument('--max_bs', type=int, required=True,
                            help='max batch size you expect')
    arg_parser.add_argument('--iterations', type=int, default=100,
                            help='how many iterations to run for caffe time')
    arg_parser.add_argument('--dummy_data_use', action='store_true',
                            help='use dummy data as the neural network input; when omitted, '
                                 'another data source like lmdb is used as input')
    return arg_parser.parse_args(argv)

def main():
    '''Entry point: parse the options, run the search and log the winner.'''
    finder = OptimalBS(parse_args())
    best_bs = finder.find_optimal_bs()
    logging.info('optimal batch size is: {}'.format(best_bs))

# Run the batch size search only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
